from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from pymilvus import MilvusClient, AnnSearchRequest, RRFRanker, RRFRanker

from app.base.logger import setup_logger

# 使用新的日志配置
logger = setup_logger("milvus_client")


@dataclass
class MilvusConnectionConfig:
    """Milvus连接配置数据类"""
    uri: str
    token: Optional[str] = None
    collection_name: str = "default_collection"
    dimension: Optional[int] = None


@dataclass
class MilvusFieldConfig:
    """Milvus字段配置 - 简单直接的字段映射"""
    id_field: str
    dense_vector_field: Optional[str] = None
    sparse_vector_field: Optional[str] = None
    output_fields: List[str] = None  # 默认输出字段列表

    def __post_init__(self):
        if self.output_fields is None:
            self.output_fields = [self.id_field]


class MilvusService:
    """
    Milvus向量数据库服务类，提供向量检索和查询功能

    要求上游传入连接配置和字段配置
    """

    def __init__(self,
                 connection_config: MilvusConnectionConfig,
                 field_config: MilvusFieldConfig):
        """
        初始化Milvus服务

        Args:
            connection_config: Milvus连接配置（必需）
            field_config: 字段配置（必需）
        """
        self.client = None
        self.collection_name = None
        self.connected = False
        self.error_message = None

        # 设置配置
        self.connection_config = connection_config
        self.field_config = field_config

    def connect(self):
        """
        连接到Milvus向量数据库

        Returns:
            bool: 连接是否成功
        """
        try:
            # 使用配置对象中的参数
            uri = self.connection_config.uri
            token = self.connection_config.token
            collection_name = self.connection_config.collection_name

            logger.info(f"正在连接到Milvus: {uri}, 集合: {collection_name}")

            # 连接到Milvus
            if token:
                self.client = MilvusClient(uri=uri, token=token)
            else:
                self.client = MilvusClient(uri=uri)

            # 检查集合是否存在
            collections = self.client.list_collections()
            if collection_name not in collections:
                logger.error(f"集合 '{collection_name}' 不存在")
                self.error_message = f"集合 '{collection_name}' 不存在"
                return False

            self.collection_name = collection_name
            self.connected = True
            self.error_message = None

            logger.info(f"成功连接到Milvus: 集合 {collection_name}")
            return True

        except Exception as e:
            logger.exception(f"连接Milvus失败: {str(e)}")
            self.error_message = f"连接异常: {str(e)}"
            self.client = None
            self.connected = False
            return False

    def disconnect(self):
        """断开Milvus连接"""
        if self.client:
            logger.info("断开Milvus数据库连接")
            self.client = None
            self.connected = False

    def search_by_vector(self, vector, limit=10, filter_expr=None, output_fields=None):
        """
        根据向量搜索最相似的节点

        Args:
            vector: 输入向量
            limit: 返回结果数量
            filter_expr: 过滤表达式
            output_fields: 输出字段列表

        Returns:
            搜索结果或None
        """
        if not self.connected or not self.client:
            logger.error("未连接到Milvus，无法执行搜索")
            return None

        try:
            # 设置默认输出字段
            if output_fields is None:
                output_fields = self.field_config.output_fields

            # 执行向量搜索
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
                "hints": "iterative_filter"
            }

            logger.debug(f"执行向量搜索，集合: {self.collection_name}, 限制: {limit}, 过滤: {filter_expr}")

            # 检查是否配置了稠密向量字段
            if not self.field_config.dense_vector_field:
                logger.error("未配置稠密向量字段，无法执行向量搜索")
                return None

            results = self.client.search(
                collection_name=self.collection_name,
                data=[vector],
                anns_field=self.field_config.dense_vector_field,
                limit=limit,
                output_fields=output_fields,
                filter=filter_expr
                # search_params=search_params
            )

            logger.info(f"向量搜索完成，结果数: {len(results) if results else 0}")
            return results

        except Exception as e:
            logger.exception(f"向量搜索失败: {str(e)}")
            return None

    def hybrid_search(self, query_text, query_dense_vector, limit=10, filter_expr=None, output_fields=None,
                      distance_threshold=None, max_distance=None):
        """
        执行混合检索，结合稠密向量和稀疏文本搜索

        Args:
            query_text: 查询文本
            query_dense_vector: 查询的稠密向量
            limit: 每个搜索请求返回的结果数量
            filter_expr: 过滤表达式
            output_fields: 输出字段列表
            distance_threshold: 距离阈值，用于过滤相似度过低的结果
            max_distance: 最大距离，用于过滤相似度过高的结果

        Returns:
            混合搜索结果或None
        """
        if not self.connected or not self.client:
            logger.error("未连接到Milvus，无法执行混合搜索")
            return None

        try:
            # 设置默认输出字段
            if output_fields is None:
                output_fields = self.field_config.output_fields

            # 构建稠密向量搜索请求的参数
            dense_param = {"nprobe": 10}

            # 添加距离过滤参数（适用于COSINE距离）
            if distance_threshold is not None:
                dense_param["radius"] = distance_threshold
            if max_distance is not None:
                dense_param["range_filter"] = max_distance

            # 检查稠密向量字段配置
            if not self.field_config.dense_vector_field:
                logger.error("未配置稠密向量字段，无法执行混合搜索")
                return None

            search_param_1 = {
                "data": [query_dense_vector],
                "anns_field": self.field_config.dense_vector_field,
                "param": dense_param,
                "limit": limit,
                "expr": filter_expr,
                "expr_params": {"hints": "iterative_filter"}
            }
            request_1 = AnnSearchRequest(**search_param_1)

            # 构建稀疏文本搜索请求的参数
            sparse_param = {"drop_ratio_search": 0.2}

            # 稀疏向量搜索也可以添加距离过滤
            if distance_threshold is not None:
                sparse_param["radius"] = distance_threshold
            if max_distance is not None:
                sparse_param["range_filter"] = max_distance

            # 检查稀疏向量字段配置
            if not self.field_config.sparse_vector_field:
                logger.error("未配置稀疏向量字段，无法执行混合搜索")
                return None

            search_param_2 = {
                "data": [query_text],
                "anns_field": self.field_config.sparse_vector_field,
                "param": sparse_param,
                "limit": limit,
                "expr": filter_expr,
                "expr_params": {"hints": "iterative_filter"}
            }
            request_2 = AnnSearchRequest(**search_param_2)

            # 组合搜索请求
            reqs = [request_1, request_2]

            logger.debug(f"执行混合搜索，集合: {self.collection_name}, 限制: {limit}, 过滤: {filter_expr}")

            ranker = RRFRanker(100)

            # 执行混合搜索
            results = self.client.hybrid_search(
                collection_name=self.collection_name,
                reqs=reqs,
                ranker=ranker,
                limit=limit,
                output_fields=output_fields
            )

            logger.info(f"混合搜索完成，结果数: {len(results) if results else 0}")
            return results

        except Exception as e:
            logger.error(f"混合搜索失败: {type(e).__name__}: {str(e)}")
            logger.exception("混合搜索详细错误信息:")
            # 如果混合搜索失败，回退到稠密向量搜索
            logger.warning("混合搜索失败，回退到稠密向量搜索")
            try:
                return self.search_by_vector(
                    vector=query_dense_vector,
                    limit=limit,
                    filter_expr=filter_expr,
                    output_fields=output_fields
                )
            except Exception as fallback_error:
                logger.error(f"回退搜索也失败: {type(fallback_error).__name__}: {str(fallback_error)}")
                return None

    def search_by_text(self, search_text, node_types=None, limit=10, embedding_function=None):
        """
        根据文本搜索最相似的节点
        
        Args:
            search_text: 搜索文本
            node_types: 节点类型列表，用于过滤结果
            limit: 返回结果数量
            embedding_function: 嵌入函数，将文本转换为向量
            
        Returns:
            搜索结果或None
        """
        if not self.connected or not self.client:
            logger.error("未连接到Milvus，无法执行搜索")
            return None

        try:
            # 确保提供了嵌入函数
            if embedding_function is None:
                logger.error("未提供嵌入函数，无法将文本转换为向量")
                return None

            # 将搜索文本转换为向量
            vector = embedding_function(search_text)

            # 构建过滤表达式
            filter_expr = None
            if node_types and len(node_types) > 0:
                node_types_str = ",".join([f"'{t}'" for t in node_types])
                filter_expr = f"node_type in [{node_types_str}]"

            # 执行向量搜索
            return self.search_by_vector(
                vector=vector,
                limit=limit,
                filter_expr=filter_expr
            )

        except Exception as e:
            logger.exception(f"文本搜索失败: {str(e)}")
            return None

    def get_error_message(self):
        """获取最后一次错误信息"""
        return getattr(self, 'error_message', "未知错误")
